SNN-CLAPP¶

In [ ]:
import matplotlib.pyplot as plt
from utils import  test_classwise, train_samplewise_clapp
from data import load_SHD
from model import CLAPP_RSNN
import numpy as np
import torch
import seaborn as sns
from scipy.signal import savgol_filter
color_list = sns.color_palette('hls', 20)
device = 'cpu'
epochs = 1
n_inputs = 700 # 28*28 #34 * 34 * 2
n_hidden = 5 * [512]
n_outputs = 20
batch_size = 64
folder = 'models/'
model_name = folder + 'shd_5layer_norec.pt'

Dataset¶

Spiking Heidelberg Digits

In [ ]:
#train_loader, test_loader = load_PMNIST(n_time_bins, scale=0.9, patches=True) #load_NMNIST(n_time_bins, batch_size=batch_size)
n_time_bins = 100
train_loader, test_loader = load_SHD(batch_size=batch_size) #load_NMNIST(n_time_bins, batch_size=batch_size)
# Plot Example
for i in range(3):
    frames, target = train_loader.next_item(-1, contrastive=True)
    plt.figure()
    plt.imshow(frames.squeeze(1).T)
    plt.colorbar()
    print(frames.shape, target)
/home/lars/ownCloud/ETH/Master/Project_2/SNN_CLAPP/data.py:17: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.y = torch.tensor(y)
torch.Size([100, 1, 700]) tensor([4.])
torch.Size([100, 1, 700]) tensor([10.])
torch.Size([100, 1, 700]) tensor([6.])

Load pretrained model¶

In [ ]:
SNN = CLAPP_RSNN(n_inputs, n_hidden, n_outputs, beta=0.95, out_proj=False, device=device, recurrent=False, cat=False).to(device)
SNN.load_state_dict(torch.load(model_name, map_location='cpu'))
from_epoch = 5
train_clapp_loss = torch.load(model_name[:-3]+'_clapp_loss.pt', map_location='cpu')[int(from_epoch*len(train_loader)/batch_size):]
print(train_clapp_loss.shape)
for i in range(train_clapp_loss.shape[-1]):
    plt.plot(from_epoch+(batch_size*np.arange(train_clapp_loss.shape[0])/len(train_loader)), savgol_filter(train_clapp_loss[:,i], 99, 1))
plt.legend([f'layer {i+1}' for i in range(len(SNN.clapp))])
# plt.ylim([-3.,-2])
plt.xlabel('Epoch')
plt.ylabel('Clapp Loss')
# train_shd_supervised_clapp(SNN, train_loader, 1, 'cpu')
torch.Size([101313, 5])
Out[ ]:
Text(0, 0.5, 'Clapp Loss')
In [ ]:
clapp_activation, target_list, clapp_losses = test_classwise(SNN, test_loader, device, batch_size=batch_size, temporal=True)
In [ ]:
print(f'Mean CLAPP Loss: {torch.stack(clapp_losses).mean(axis=0)}')
plt.plot(torch.stack(clapp_losses).detach().cpu()[:100,:])
print(torch.stack(clapp_losses).mean(axis=0))
Mean CLAPP Loss: tensor([-113.9548, -159.1156, -196.8741, -210.1194, -227.6216])
tensor([-113.9548, -159.1156, -196.8741, -210.1194, -227.6216])

Analyze Weights Directly¶

In [ ]:
layers = [SNN.clapp[0].fc.weight[:,:n_inputs]]
for i in range(1, len(SNN.clapp)):
    layers.append(SNN.clapp[i].fc.weight[:,:n_hidden[i-1]] @ layers[-1])

for i in range(len(SNN.clapp)):
    plt.figure()
    plt.imshow(SNN.clapp[i].fc.weight.detach(), vmax=0.05, vmin=-0.05)
    plt.colorbar()
    # plt.figure()
    # plt.imshow(SNN.clapp[i].pred.weight.detach(), vmax=0.5, vmin=-0.5)
    # plt.colorbar()
for lay in layers:
    plt.figure()
    plt.imshow(lay.detach())
    plt.colorbar()
In [ ]:
print(len(clapp_activation))
hidden_activities_transformed = torch.stack(clapp_activation).swapaxes(0,1).reshape(len(SNN.clapp), -1, 512)
target_transformed = torch.stack(target_list).flatten()
print(hidden_activities_transformed.shape, target_transformed.shape)

from sklearn.decomposition import PCA
from umap import UMAP
from sklearn.manifold import TSNE
reduction = TSNE
colors = [color_list[i.int()] for i in target_transformed]
for i, hat in enumerate(hidden_activities_transformed):
    reduct = reduction(n_components=2)
    # hat_diff = hat[1:] - hat[:-1]
    print(hat.shape)
    hat_transform = reduct.fit_transform(hat.detach().cpu().numpy())
    print(hat_transform.shape)
    print(f'Total spikes in layer {i}: {hat.sum()}')
    plt.figure()
    plt.imshow(hat, vmax=15)
    plt.colorbar()
    plt.figure(figsize=(8,8))
    plt.title(f'Hidden Activations {i}')
    col = colors
    for i in range(train_loader.num_classes):
        col_indeces = np.argwhere(target_transformed.squeeze() == i).squeeze()
        hat_col = hat_transform[col_indeces, :]
        plt.scatter(hat_col[:,0], hat_col[:,1], s=6, color=color_list[i], label=i, alpha=1)
    plt.legend()
36
torch.Size([5, 2304, 512]) torch.Size([2304])
torch.Size([2304, 512])
(2304, 2)
Total spikes in layer 0: 4584499.0
torch.Size([2304, 512])
(2304, 2)
Total spikes in layer 1: 4654315.0
torch.Size([2304, 512])
(2304, 2)
Total spikes in layer 2: 5770904.0
torch.Size([2304, 512])
(2304, 2)
Total spikes in layer 3: 5382842.0
torch.Size([2304, 512])
(2304, 2)
Total spikes in layer 4: 6065976.0

Train output Projection¶

In [ ]:
from model import CLAPP_out
from tqdm.notebook import tqdm
def train_out_proj(epochs, batch):
    # train output projections from all layers (and no layer)
    losses_out = []
    optimizers = []
    out_projs = []
    print_interval = 10*batch
    # SNN.out_proj.out_proj.reset_parameters()
    out_proj_0 = CLAPP_out(700, 20, beta=0.95)
    optim_0 = torch.optim.SGD(out_proj_0.parameters(), lr=1e-4)
    for lay in range(len(SNN.clapp)):
        out_projs.append(CLAPP_out(512, 20, beta=0.95))
        optimizers.append(torch.optim.SGD(out_projs[-1].parameters(), lr=1e-4))
        optimizers[-1].zero_grad()
    SNN.eval()
    target = batch*[0]
    acc = []
    correct = (len(SNN.clapp) + 1)*[0]
    with torch.no_grad():
        pbar = tqdm(total=len(train_loader)*epochs)
        while len(losses_out)*batch < len(train_loader)*epochs:
            data, target = train_loader.next_item(target, contrastive=True)
            SNN.reset(0)
            logit_lists = [[] for lay in range(len(SNN.clapp)+1)]
            data = data.squeeze()
            for step in range(data.shape[0]):
                data_step = data[step].float().to(device)
                target = target.to(device)
                logits, mem_his, clapp_loss = SNN(data_step, target, 0)
                logts, _ = out_proj_0(data_step, target)
                logit_lists[0].append(logts)
                for lay in range(len(SNN.clapp)):
                    logts, _ = out_projs[lay](logits[lay], target)
                    logit_lists[lay+1].append(logts)
            
            preds = [torch.stack(logit_lists[lay]).sum(axis=0) for lay in range(len(SNN.clapp)+1)]
            # if pred.max() < 1: print(pred.max())
            dL = [preds[lay].argmax(axis=-1) == target for lay in range(len(SNN.clapp)+1)]
            out_proj_0.reset(1-dL[0].float())
            for i, out_proj in enumerate(out_projs):
                out_proj.reset(1-dL[i+1].float())

            correct = [correct[lay] + dL[lay].sum() for lay in range(len(SNN.clapp)+1)]
            losses_out.append(torch.tensor([torch.nn.functional.cross_entropy(preds[lay], target.squeeze().long()) for lay in range(len(SNN.clapp)+1)], requires_grad=False))

            optim_0.step()
            optim_0.zero_grad()
            for opt in optimizers:
                opt.step()
                opt.zero_grad()
            
            if len(losses_out)*batch % print_interval == 0:
                pbar.write(f'Cross Entropy Loss: {(torch.stack(losses_out)[-400//batch:].sum(dim=0)/(400//batch)).numpy()}\n' +
                           f'Correct: {100*np.array(correct)/print_interval}%')
                acc.append(np.array(correct)/print_interval)
                correct = (len(SNN.clapp) + 1)*[0]
            pbar.update(batch)
    return [out_proj_0, *out_projs], np.asarray(acc), torch.stack(losses_out)

with torch.no_grad():
    out_projs, acc, losses_out = train_out_proj(10, 80)
  0%|          | 0/81560 [00:00<?, ?it/s]
Cross Entropy Loss: [3.0592885 3.304688  3.2149575 3.2233822 2.3872805 3.4384217]
Correct: [ 2.875 11.375 18.    15.75  27.    15.   ]%
Cross Entropy Loss: [3.0628736 3.1732297 2.8184757 2.1746995 1.9051759 2.5545602]
Correct: [ 4.5   17.875 25.875 41.5   55.75  36.375]%
Cross Entropy Loss: [3.1697598 2.983493  2.0265057 1.8123022 1.6135648 2.02384  ]
Correct: [ 5.625 24.25  40.75  52.75  58.75  45.75 ]%
Cross Entropy Loss: [3.1936996 2.6097703 2.0575373 1.5988777 1.4927404 1.4794581]
Correct: [ 4.875 26.    45.25  59.125 60.5   61.5  ]%
Cross Entropy Loss: [3.0790238 2.4289439 1.6033024 1.3053102 1.1618885 1.264855 ]
Correct: [ 7.375 30.25  47.125 60.25  65.125 69.   ]%
Cross Entropy Loss: [3.1002088  2.2225099  1.4810988  1.0873731  0.81291664 1.113727  ]
Correct: [ 8.625 33.625 51.125 62.75  72.875 72.625]%
Cross Entropy Loss: [3.165258  2.046066  1.1968672 0.7255853 0.6862627 1.0534015]
Correct: [ 7.875 39.    59.875 71.5   76.625 73.875]%
Cross Entropy Loss: [3.1344364  2.179806   1.4766083  0.94342005 0.7400807  1.0225474 ]
Correct: [ 7.5   37.875 53.75  71.625 78.5   73.   ]%
Cross Entropy Loss: [3.1401048 2.1995368 1.4009577 0.827705  0.5798048 0.8583808]
Correct: [ 8.    36.5   54.75  70.875 79.25  73.875]%
Cross Entropy Loss: [3.3764114  2.0199063  1.2509757  0.75755894 0.60538054 0.868153  ]
Correct: [ 9.25 38.   56.75 72.5  79.25 76.  ]%
Cross Entropy Loss: [3.0230844  2.1256597  1.0526004  0.68123925 0.61539155 1.0028805 ]
Correct: [11.625 40.625 63.125 77.25  80.875 75.5  ]%
Cross Entropy Loss: [3.0273972  2.1773567  1.1713005  0.7561097  0.63606775 0.8772826 ]
Correct: [ 8.5   34.25  61.5   75.75  79.5   76.125]%
Cross Entropy Loss: [3.2097042  2.0563214  1.1072843  0.9506593  0.80205727 0.931341  ]
Correct: [ 9.625 39.75  62.5   72.875 79.    77.5  ]%
Cross Entropy Loss: [3.0743937  2.1993039  1.097712   0.6858389  0.5766239  0.80621386]
Correct: [ 9.125 37.5   63.75  74.75  80.875 80.5  ]%
Cross Entropy Loss: [3.0671723 2.0141437 1.1794173 0.7060026 0.5225506 0.5934661]
Correct: [ 9.75  38.5   59.625 75.125 81.625 82.5  ]%
Cross Entropy Loss: [2.9994457 2.1309588 1.1689099 0.65776   0.5776539 0.7486261]
Correct: [ 9.125 39.5   63.625 78.625 82.625 80.75 ]%
Cross Entropy Loss: [3.0490613  1.7696857  0.87877256 0.6178111  0.5513927  0.79168826]
Correct: [ 9.125 41.5   62.    77.375 79.75  80.   ]%
Cross Entropy Loss: [3.0263753  1.8986218  1.0611737  0.69556695 0.6273117  0.7438471 ]
Correct: [ 9.75  41.75  65.625 75.    80.875 80.25 ]%
Cross Entropy Loss: [3.0008912  2.007826   0.98661995 0.6089993  0.5939046  0.6744814 ]
Correct: [10.125 38.    67.5   77.75  82.375 82.125]%
Cross Entropy Loss: [3.1105425  1.6115786  0.9665772  0.5954911  0.5312096  0.57376426]
Correct: [10.5   44.25  63.75  77.75  81.625 80.5  ]%
Cross Entropy Loss: [3.0068498  2.02105    0.8766573  0.6144441  0.60451126 0.6416334 ]
Correct: [ 9.75  41.    64.    77.25  81.    81.375]%
Cross Entropy Loss: [3.160195   1.950182   0.9932415  0.68624324 0.58092666 0.7089085 ]
Correct: [ 8.625 38.25  64.5   78.25  84.    83.875]%
Cross Entropy Loss: [3.147021   1.6818091  0.97858095 0.6021178  0.50018394 0.4438458 ]
Correct: [ 8.625 43.875 67.25  78.875 82.625 84.625]%
Cross Entropy Loss: [3.0289884  1.6315044  1.0720955  0.6997339  0.64312404 0.6109128 ]
Correct: [ 9.375 44.875 65.875 77.375 81.25  83.75 ]%
Cross Entropy Loss: [3.1076818  1.7735803  0.9963292  0.6965545  0.5888001  0.67020994]
Correct: [ 9.75  39.25  68.625 78.5   81.125 83.375]%
Cross Entropy Loss: [3.0140233  1.7964083  1.0170692  0.71051633 0.6399194  0.62391514]
Correct: [ 7.25  44.125 64.625 78.5   80.    82.5  ]%
Cross Entropy Loss: [3.2983482  1.4336803  0.8651347  0.59745806 0.59583783 0.47938347]
Correct: [ 7.75  49.    69.625 79.75  81.625 85.5  ]%
Cross Entropy Loss: [3.0227172 1.6308626 0.8873695 0.5688719 0.5236004 0.4128879]
Correct: [10.    44.5   66.375 78.625 80.625 85.   ]%
Cross Entropy Loss: [3.1221206  1.8216858  1.086255   0.5396703  0.4476328  0.45825785]
Correct: [10.375 43.625 62.75  80.125 83.5   87.   ]%
Cross Entropy Loss: [3.0165932  1.594445   0.82931024 0.62122405 0.56433904 0.53335106]
Correct: [10.75  45.875 69.25  76.625 83.125 84.   ]%
Cross Entropy Loss: [2.976345   1.7552588  0.8727827  0.5439472  0.487845   0.48407856]
Correct: [ 8.625 40.75  68.875 80.875 82.875 85.375]%
Cross Entropy Loss: [3.1561894 2.109057  1.0499103 0.649124  0.6493728 0.6012848]
Correct: [10.75  39.375 63.    81.75  81.875 84.5  ]%
Cross Entropy Loss: [3.0022197  1.6594759  0.88983524 0.6084543  0.4750777  0.63930285]
Correct: [ 8.875 40.625 65.375 80.75  83.    84.625]%
Cross Entropy Loss: [3.0250363  1.6357654  0.7875086  0.5366533  0.4635105  0.38862798]
Correct: [ 8.25  43.875 69.875 79.75  84.5   85.625]%
Cross Entropy Loss: [3.0546546  1.8087715  1.1468611  0.7362593  0.52866495 0.54879266]
Correct: [10.5   44.375 58.875 74.125 82.5   85.75 ]%
Cross Entropy Loss: [2.9471247  1.7242441  0.8533147  0.6239732  0.46459252 0.4667224 ]
Correct: [ 8.75  43.    69.625 79.375 83.75  85.625]%
Cross Entropy Loss: [2.9403381  2.424672   1.1020358  0.64152163 0.48979133 0.54925853]
Correct: [ 9.75  39.75  65.625 78.5   85.    84.75 ]%
Cross Entropy Loss: [2.9274268  1.7775714  1.0676163  0.7637645  0.59661067 0.696983  ]
Correct: [12.5   47.125 67.    76.125 83.25  82.375]%
Cross Entropy Loss: [3.0133202 2.0769572 1.0828611 0.6973138 0.6013626 0.6315218]
Correct: [13.375 41.875 62.5   74.75  82.375 83.75 ]%
Cross Entropy Loss: [3.2284603  1.7134516  0.9874684  0.72427064 0.5652125  0.6226713 ]
Correct: [12.    41.125 67.25  74.625 81.125 83.125]%
Cross Entropy Loss: [3.4371686  1.851958   0.8499395  0.51909226 0.45415252 0.37409484]
Correct: [12.75  38.25  67.5   78.625 80.625 84.   ]%
Cross Entropy Loss: [3.0054333  1.6608994  0.8732424  0.6181227  0.50393623 0.5524405 ]
Correct: [11.5  46.25 72.25 80.75 86.   84.75]%
Cross Entropy Loss: [3.0089493  1.8744663  1.0633396  0.6948918  0.62322265 0.69205284]
Correct: [10.375 42.5   68.625 77.75  82.25  83.   ]%
Cross Entropy Loss: [3.0932999  1.7876904  0.9634226  0.5683807  0.48349482 0.42919135]
Correct: [ 9.125 43.    67.5   80.625 85.25  86.5  ]%
Cross Entropy Loss: [3.19741    2.1482918  0.9270364  0.50343615 0.49512953 0.43248376]
Correct: [ 9.875 39.625 64.375 82.125 84.    85.875]%
Cross Entropy Loss: [2.9276383  1.4919791  0.8011287  0.53904796 0.465405   0.44272572]
Correct: [13.875 48.875 70.5   79.625 84.375 87.25 ]%
Cross Entropy Loss: [3.0575027 1.6998974 1.2657552 0.8699988 0.6387011 0.5657667]
Correct: [14.    49.375 64.125 77.25  81.875 84.875]%
Cross Entropy Loss: [3.0627096  1.6546942  1.0380795  0.6233481  0.54562175 0.5222074 ]
Correct: [14.5   44.75  69.    81.375 83.25  84.625]%
Cross Entropy Loss: [2.8974407  1.957535   1.0858295  0.80741656 0.6602985  0.5484301 ]
Correct: [11.125 42.25  65.25  75.875 81.5   83.   ]%
Cross Entropy Loss: [3.4451847  1.9205099  0.8657915  0.4949943  0.4703858  0.42402235]
Correct: [11.875 48.875 71.5   82.625 84.25  87.25 ]%
Cross Entropy Loss: [2.9432416 2.0023696 0.7916729 0.5706314 0.5208697 0.5321156]
Correct: [12.5   38.625 68.    82.5   83.5   84.875]%
Cross Entropy Loss: [3.030706   1.6219174  0.881166   0.56137925 0.49297038 0.48670998]
Correct: [ 9.875 46.625 67.5   78.5   81.125 85.5  ]%
Cross Entropy Loss: [3.1968966  1.8695322  1.1020805  0.77584946 0.5797532  0.51994276]
Correct: [10.25  38.25  61.625 75.875 80.625 83.125]%
Cross Entropy Loss: [3.040295   1.6315901  0.8334408  0.6557951  0.50152415 0.61000997]
Correct: [11.375 45.25  71.    80.5   85.625 84.625]%
Cross Entropy Loss: [2.9667509  1.7627857  0.8647931  0.60715836 0.60728353 0.5163029 ]
Correct: [ 9.75  46.375 72.25  80.5   79.75  85.75 ]%
Cross Entropy Loss: [2.9047413  1.8478705  0.80170983 0.5453348  0.5021038  0.47545013]
Correct: [10.    45.5   69.5   80.5   83.25  84.375]%
Cross Entropy Loss: [2.8838997  1.8833306  0.97277486 0.58850855 0.4655931  0.4989623 ]
Correct: [13.5   39.875 68.75  78.625 81.75  82.875]%
Cross Entropy Loss: [2.9431396 1.7508295 0.7993283 0.6839116 0.5515451 0.5776764]
Correct: [14.75  44.375 69.5   79.875 83.125 84.125]%
Cross Entropy Loss: [3.1627584  1.8046802  1.0227302  0.54682213 0.47725177 0.4406547 ]
Correct: [13.    44.    68.125 82.125 85.375 86.375]%
Cross Entropy Loss: [3.0061405  1.7326323  0.8558343  0.55225575 0.4413685  0.42838746]
Correct: [ 9.875 43.    64.875 79.75  85.625 84.625]%
Cross Entropy Loss: [3.0251403  2.1437316  0.87921894 0.5408951  0.494903   0.49094287]
Correct: [11.    43.    67.875 79.5   83.375 83.375]%
Cross Entropy Loss: [2.9322987  1.652333   0.6880803  0.49353772 0.35898796 0.36407328]
Correct: [15.125 45.75  76.25  82.625 85.375 87.375]%
Cross Entropy Loss: [3.200056   1.6237847  0.82318056 0.5221631  0.42583814 0.43888077]
Correct: [11.5   47.875 73.5   81.125 87.    85.5  ]%
Cross Entropy Loss: [3.0273337  1.9252964  0.876787   0.67927283 0.5911812  0.54684985]
Correct: [12.625 40.625 70.    78.    81.75  85.375]%
Cross Entropy Loss: [3.2836616  1.6204703  0.9664984  0.62907064 0.49594408 0.533266  ]
Correct: [10.    46.875 69.    79.    84.125 84.   ]%
Cross Entropy Loss: [3.0609908  1.8892581  1.041799   0.6662888  0.41626057 0.4137256 ]
Correct: [14.75  44.625 67.25  77.75  85.25  88.   ]%
Cross Entropy Loss: [2.8581684  1.8039081  1.0574852  0.60612345 0.51491916 0.54161006]
Correct: [12.125 42.125 65.25  79.    82.5   83.625]%
Cross Entropy Loss: [3.1883678 1.5175365 0.7681623 0.5504075 0.5069863 0.4727601]
Correct: [13.875 45.    71.5   80.125 84.125 85.375]%
Cross Entropy Loss: [4.5461597  1.9413345  1.2279657  0.7244     0.50492215 0.47895068]
Correct: [10.75  41.625 64.    77.625 84.125 84.375]%
Cross Entropy Loss: [2.8848042  2.0539155  0.86577195 0.50719327 0.3743063  0.42951965]
Correct: [12.375 42.375 72.125 82.625 85.75  86.625]%
Cross Entropy Loss: [3.1026292  2.6553998  1.1793092  0.76385534 0.7186745  0.67797405]
Correct: [10.75  39.625 66.75  81.625 82.    84.75 ]%
Cross Entropy Loss: [2.9684005  1.5610086  0.82370996 0.53371507 0.48485833 0.48094493]
Correct: [13.625 46.75  69.25  81.125 83.5   85.75 ]%
Cross Entropy Loss: [3.4059956 2.1697114 1.169297  0.5937111 0.5299756 0.5177182]
Correct: [10.5   40.875 66.25  79.75  83.875 84.625]%
Cross Entropy Loss: [2.981723   1.5343466  0.7402752  0.5267539  0.39910927 0.41904512]
Correct: [12.125 46.125 74.375 82.625 87.    88.   ]%
Cross Entropy Loss: [2.8993185 1.6620448 0.8629214 0.6263448 0.6495975 0.5531028]
Correct: [13.125 43.125 69.125 80.125 82.5   85.875]%
Cross Entropy Loss: [2.8841465  1.6387984  0.6655936  0.46667004 0.40492797 0.3920607 ]
Correct: [10.625 41.375 72.5   82.125 83.875 85.25 ]%
Cross Entropy Loss: [3.0970635  1.6737862  0.87659276 0.5304945  0.52965105 0.4708584 ]
Correct: [12.75  47.    71.375 82.75  83.875 85.625]%
Cross Entropy Loss: [3.2678325  1.5704967  0.83065474 0.5681037  0.5583045  0.5014496 ]
Correct: [14.5   46.125 73.875 81.125 80.875 84.375]%
Cross Entropy Loss: [3.4222984  1.5396487  0.9070364  0.6226528  0.48775855 0.51048726]
Correct: [10.875 45.5   67.5   79.75  84.375 84.375]%
Cross Entropy Loss: [3.081642   1.6114861  1.1220353  0.6021444  0.47716507 0.49586043]
Correct: [11.75  48.125 63.375 80.375 84.5   85.75 ]%
Cross Entropy Loss: [3.061386   1.5051836  0.6441544  0.6589028  0.51302254 0.47163734]
Correct: [ 8.875 43.75  72.375 80.25  85.5   87.25 ]%
Cross Entropy Loss: [2.9992366  1.9130338  1.0048182  0.68723977 0.5332165  0.54111946]
Correct: [12.5   40.375 72.375 81.5   83.5   86.375]%
Cross Entropy Loss: [3.0111508  2.0618596  0.878984   0.5489829  0.48416367 0.50063264]
Correct: [15.375 37.5   67.75  80.5   83.25  84.875]%
Cross Entropy Loss: [3.0650299  1.8617165  0.79462755 0.5597135  0.55789185 0.4707482 ]
Correct: [13.125 48.375 71.375 82.25  82.75  85.625]%
Cross Entropy Loss: [2.86556    1.4245446  0.81697196 0.60661584 0.46801043 0.47327223]
Correct: [13.125 50.5   68.625 80.    84.125 86.625]%
Cross Entropy Loss: [3.5813553  1.7462721  0.6996439  0.45092177 0.3761281  0.4439034 ]
Correct: [12.25  48.125 71.    83.5   85.    85.375]%
Cross Entropy Loss: [2.8774676  1.9664905  1.0307274  0.50789547 0.43523592 0.43112522]
Correct: [13.25  40.25  70.5   82.125 84.5   85.625]%
Cross Entropy Loss: [2.9437509  1.6426048  0.9455849  0.67471564 0.6092092  0.5702437 ]
Correct: [13.125 44.125 68.625 79.875 80.25  83.5  ]%
Cross Entropy Loss: [2.9589424  1.5227692  0.78685045 0.5386065  0.46356648 0.49786806]
Correct: [14.375 45.125 71.125 82.125 85.75  86.5  ]%
Cross Entropy Loss: [3.0250165  1.7240454  0.7977293  0.54554    0.43309125 0.3884979 ]
Correct: [10.625 47.375 73.    82.    84.75  86.75 ]%
Cross Entropy Loss: [3.129272   1.967033   0.9577776  0.58596694 0.49101862 0.5307419 ]
Correct: [15.25  44.125 68.125 79.875 84.25  84.25 ]%
Cross Entropy Loss: [2.9101024 1.7318776 1.1131595 0.5208601 0.4648195 0.4657405]
Correct: [17.25  46.125 67.5   84.375 84.375 85.375]%
Cross Entropy Loss: [3.5960832  1.9333878  0.7571655  0.60292566 0.50618345 0.4628747 ]
Correct: [10.    40.25  70.375 79.125 82.25  85.625]%
Cross Entropy Loss: [3.029155   1.4889362  0.73992825 0.5707727  0.45236874 0.43567744]
Correct: [12.25  48.875 73.    82.125 83.125 87.   ]%
Cross Entropy Loss: [2.8766136  1.6459885  0.84303045 0.56460893 0.3991136  0.41761708]
Correct: [13.875 42.875 69.875 79.625 84.    86.   ]%
Cross Entropy Loss: [3.4214454  1.5972183  0.8658182  0.64118004 0.56899863 0.5759396 ]
Correct: [13.25  50.875 71.625 79.875 83.625 82.625]%
Cross Entropy Loss: [2.9154363  1.7812408  0.7149235  0.5127851  0.459591   0.45135212]
Correct: [10.5   41.375 70.625 81.    86.125 86.625]%
Cross Entropy Loss: [2.9860187  1.6418374  0.7283004  0.47905737 0.4052481  0.3719082 ]
Correct: [10.125 48.5   74.5   82.25  85.375 87.625]%
Cross Entropy Loss: [2.9465632  2.1019638  0.7686553  0.47209206 0.4784643  0.434412  ]
Correct: [10.75  40.5   70.75  82.875 85.125 86.625]%
Cross Entropy Loss: [2.9009266  1.6417015  0.8166645  0.58031374 0.4775436  0.4580237 ]
Correct: [14.625 46.25  67.25  80.375 83.5   84.375]%
Cross Entropy Loss: [2.90346    2.0258083  0.9324237  0.5852523  0.47885126 0.424465  ]
Correct: [10.125 43.75  71.5   79.625 84.5   84.75 ]%
Cross Entropy Loss: [2.9447227  2.3738275  1.2133219  0.5977822  0.43241492 0.36062104]
Correct: [10.25 43.5  67.5  77.5  84.5  87.75]%
In [ ]:
print(f'Accuracy of last quarter: {100*acc[-len(acc)//4:].mean(axis=0)}%')
plt.figure()
plt.plot(np.asarray(acc)*100)
plt.ylabel('Accuracy [%]')
plt.xlabel('Training Step [x500]')
labels = ['From Inputs directly', *[f'From Layer {i+1}' for i in range(len(SNN.clapp))]]
plt.legend(labels)
# plt.ylim([75, 90])
plt.figure()
print(losses_out.shape)
for i in range(losses_out.shape[1]):
    plt.plot(np.arange(len(losses_out))/len(train_loader), savgol_filter(losses_out[:,i], 99, 1), label=labels[i])
plt.ylabel('Cross Entropy Loss')
plt.xlabel('Training Step')
plt.legend();
Accuracy of last quarter: [12.49038462 44.97115385 70.21153846 81.02884615 83.99038462 85.66346154]%
torch.Size([1020, 6])

Get output projection Accuracy on test set¶

In [ ]:
from tqdm.notebook import trange
correct = torch.zeros(len(out_projs))
for out_proj in out_projs:
    out_proj.eval()
SNN.eval()
pred_matrix = torch.zeros(n_outputs, n_outputs)
for idx in trange(0, len(test_loader), batch_size):
    for out_proj in out_projs:
        out_proj.reset()
    SNN.reset(0)
    inp, target = test_loader.x[idx:idx+batch_size], test_loader.y[idx:idx+batch_size]
    logits = len(out_projs)*[torch.zeros((inp.shape[0],20))]
    for step in range(inp.shape[1]):
        data_step = inp[:,step].float().to(device)
        spk_step, _, _ = SNN(data_step, None, 0)
        spk_step = [data_step, *spk_step]
        for i, out_proj in enumerate(out_projs):
            out, _ = out_proj(spk_step[i], target)
            logits[i] = logits[i] + out
    for i, logit in enumerate(logits):
        pred = logit.argmax(axis=-1)
        correct[i] += int((pred == target).sum())
    # for the last layer create the prediction matrix
    for j in range(pred.shape[0]):
        pred_matrix[int(target[j]), int(pred[j])] += 1
correct /= len(test_loader)
print('Directly from inputs:')
print(f'Accuracy: {100*correct[0]:.2f}%')
for i in range(len(out_projs)-1):
    print(f'From layer {i+1}:')
    print(f'Accuracy: {100*correct[i+1]:.2f}%')

plt.imshow(pred_matrix, origin='lower')
plt.title('Prediction Matrix for the final layer')
plt.xlabel('Prediction')
plt.ylabel('Target')
plt.xticks([i for i in range(n_outputs)])
plt.yticks([i for i in range(n_outputs)])
plt.colorbar();
  0%|          | 0/36 [00:00<?, ?it/s]
Directly from inputs:
Accuracy: 12.10%
From layer 1:
Accuracy: 38.21%
From layer 2:
Accuracy: 51.94%
From layer 3:
Accuracy: 61.88%
From layer 4:
Accuracy: 67.14%
From layer 5:
Accuracy: 67.80%

Analyze the final Network¶

In [ ]:
print(torch.diag(pred_matrix).sum()/pred_matrix.sum())
print(pred_matrix.diag().sum(), pred_matrix.sum(), len(test_loader))
from snntorch import spikeplot as spkplt
SNN.eval()
for out_proj in out_projs:
    out_proj.eval()
data_lastlayer = torch.zeros(n_time_bins, 512)
data_out = torch.zeros(n_time_bins, 20)
for idx in range(5):
    for step in range(inp.shape[1]):
        data_step = inp[:,step].float().to(device)
        spk_step, _, _ = SNN(data_step, None, 0)
        data_lastlayer[step] = spk_step[-1][idx]
        out, _ = out_projs[-1](spk_step[-1][idx], target)
        data_out[step] = out[0]

    print(target[idx])
    fig = plt.figure(facecolor="w", figsize=(10, 5))
    ax = fig.add_subplot(111)
    print(data_lastlayer.mean(), data_out.mean())
    spkplt.raster(data_lastlayer, ax, s=1.5, color='black')
    fig = plt.figure(facecolor="w", figsize=(10, 5))
    ax = fig.add_subplot(111)
    spkplt.raster(data_out, ax, s=5, color='black')
tensor(0.6780)
tensor(1535.) tensor(2264.) 2264
tensor(18.)
tensor(0.0378) tensor(0.0255)
tensor(8.)
tensor(0.0652) tensor(0.0240)
tensor(4.)
tensor(0.0674) tensor(0.0235)
tensor(1.)
tensor(0.0632) tensor(0.0210)
tensor(8.)
tensor(0.0472) tensor(0.0215)